import os
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import numpy as np
import time
import random
import math
import copy
from matplotlib import pyplot as plt
import logging
from datetime import datetime
import argparse

from ofa.model_zoo import ofa_net
from ofa.utils import download_url
from ofa.tutorial import AccuracyPredictor, FLOPsTable, EvolutionFinder
from ofa.tutorial import evaluate_ofa_subnet

def setup_logger(log_dir, log_name='search_and_eval.log'):
    """Configure the 'ofa_search' logger to write to a file and to the console.

    Args:
        log_dir: Directory that receives the log file; created if missing.
        log_name: File name of the log inside ``log_dir``.

    Returns:
        The ``logging.Logger`` named ``'ofa_search'``, at INFO level, with
        exactly one file handler and one stream handler attached.

    ``logging.getLogger`` returns the same logger object for the same name
    within a process. Any handlers left by an earlier call are therefore
    closed and detached first. Otherwise every call would stack another
    handler pair, so each message would be emitted once per call and the old
    log files would stay open.
    """
    logger = logging.getLogger('ofa_search')
    logger.setLevel(logging.INFO)

    # Iterate over a copy: removeHandler mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    os.makedirs(log_dir, exist_ok=True)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handlers = [
        logging.FileHandler(os.path.join(log_dir, log_name), encoding='utf-8'),
        logging.StreamHandler(),
    ]
    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

def _parse_args():
    """Parse the command-line options for the search-and-evaluate run."""
    parser = argparse.ArgumentParser(description='OFA Network Search')
    parser.add_argument('--imagenet-path', type=str, default='datasets/ImageNet',
                        help='Path to ImageNet dataset')
    parser.add_argument('--flops', type=int, default=330,
                        help='FLOPs constraint (M)')
    parser.add_argument('--mutation-ratio', type=float, default=0.5,
                        help='Mutation ratio')
    parser.add_argument('--population-size', type=int, default=50,
                        help='Population size')
    parser.add_argument('--max-time-budget', type=int, default=100,
                        help='Maximum number of iterations')
    parser.add_argument('--parent-ratio', type=float, default=0.2,
                        help='Parent ratio')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed')
    # Previously hard-coded loader settings; the defaults keep the old behaviour.
    parser.add_argument('--batch-size', type=int, default=250,
                        help='Batch size of the ImageNet validation loader')
    parser.add_argument('--num-workers', type=int, default=16,
                        help='Number of data-loading worker processes')
    return parser.parse_args()


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode chooses kernels by timing them, which can vary between
    # runs; keep it off (its default) explicitly for reproducibility.
    torch.backends.cudnn.benchmark = False


def _build_val_transform(size):
    """Return the eval transform: resize to ceil(size / 0.875), center-crop to
    ``size``, convert to a tensor and normalize with ImageNet mean/std."""
    return transforms.Compose([
        transforms.Resize(int(math.ceil(size / 0.875))),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])


def _build_val_loader(imagenet_path, batch_size=250, num_workers=16):
    """Return a DataLoader over ``<imagenet_path>/val`` with the 224px eval transform."""
    return torch.utils.data.DataLoader(
        datasets.ImageFolder(
            root=os.path.join(imagenet_path, 'val'),
            transform=_build_val_transform(224)
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )


def main():
    """Run an OFA evolutionary search under a FLOPs constraint, then evaluate it.

    Steps:
      1. Parse the CLI options.
      2. Set up logging (file + console) under
         ``search_logs/<script name>/<search config>/``.
      3. Seed all RNGs.
      4. Build the ImageNet validation loader.
      5. Load the pretrained OFA supernet plus the accuracy predictor and the
         FLOPs lookup table.
      6. Run the evolution search for each FLOPs constraint.
      7. Evaluate every found sub-network on ImageNet and log its top-1.
    """
    args = _parse_args()

    # One directory per (script, search configuration) so runs don't clobber
    # each other. splitext strips only the extension.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    log_dir = os.path.join(
        'search_logs',
        script_name,
        f'flops{args.flops}_pop{args.population_size}'
        f'_iter{args.max_time_budget}_seed{args.seed}',
    )
    logger = setup_logger(log_dir)

    # Set random seeds for reproducibility
    _seed_everything(args.seed)

    data_loader = _build_val_loader(args.imagenet_path, args.batch_size, args.num_workers)
    logger.info('The ImageNet dataloader is ready.')

    # Load OFA network
    supernet_name = 'ofa_mbv3_d234_e346_k357_w1.2'
    ofa_network = ofa_net(supernet_name, pretrained=True)
    logger.info(f'The OFA Network is ready. {supernet_name}')

    # Build predictors
    accuracy_predictor = AccuracyPredictor()
    logger.info('The accuracy predictor is ready!')
    flops_lookup_table = FLOPsTable()
    logger.info('The FLOPs lookup table is ready!')

    # Evolution search parameters
    params = {
        'efficiency_constraint': args.flops,  # FLOPs constraint (M)
        'efficiency_predictor': flops_lookup_table,
        'accuracy_predictor': accuracy_predictor,
        'logger': logger,
        'mutation_ratio': args.mutation_ratio,
        'population_size': args.population_size,
        'max_time_budget': args.max_time_budget,
        'parent_ratio': args.parent_ratio,
        'seed': args.seed,
    }

    # Log search configuration
    logger.info("Search Configuration:")
    for k, v in vars(args).items():
        logger.info(f"{k}: {v}")

    finder = EvolutionFinder(**params)

    # Start search. The list may hold several FLOPs constraint points.
    result_lis = []
    for flops in [args.flops]:
        start = time.time()
        finder.set_efficiency_constraint(flops)
        best_info = finder.run_evolution_search()
        elapsed = time.time() - start
        logger.info(f'Found best architecture at flops <= {flops:.2f}M in {elapsed:.2f} seconds!')
        # best_info[0] is the predicted accuracy (a fraction), best_info[2] the MFLOPs.
        logger.info(f'It achieves {best_info[0]*100:.2f}% predicted accuracy with {best_info[2]:.2f} MFLOPs.')
        result_lis.append(best_info)

    # Evaluate results
    top1s = []
    flops_lis = []
    for result in result_lis:
        # Same 4-tuple as best_info above: index 1 = sub-network config, index 2 = MFLOPs.
        _, net_config, subnet_flops, _ = result
        logger.info(f'Evaluating the sub-network {net_config} with FLOPs = {subnet_flops:.1f}M')
        # NOTE(review): upstream OFA's evaluate_ofa_subnet may also require a
        # `batch_size` argument after `data_loader`. Confirm that the installed
        # ofa version accepts this 4-argument call.
        top1 = evaluate_ofa_subnet(
            ofa_network,
            args.imagenet_path,
            net_config,
            data_loader)
        logger.info(f'The sub-network {net_config} achieves {top1:.2f}% accuracy')
        logger.info('-' * 45)
        top1s.append(top1)
        flops_lis.append(subnet_flops)

    # Final recap so the measured numbers are together at the end of the log.
    for subnet_flops, top1 in zip(flops_lis, top1s):
        logger.info(f'Summary: {subnet_flops:.1f} MFLOPs -> {top1:.2f}% top-1')


if __name__ == '__main__':
    main()